package org.bjf.aop;

import com.alibaba.fastjson.JSONObject;
import java.util.Date;
import java.util.Map;
import javax.servlet.http.HttpServletRequest;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.bjf.exception.CommMsgCode;
import org.bjf.exception.ServiceException;
import org.bjf.modules.core.web.core.LoginInfo;
import org.bjf.modules.core.web.core.ThreadContext;
import org.bjf.modules.user.enums.UserRedisKey;
import org.bjf.utils.RedisUtil;
import org.bjf.utils.TokenUtil;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;

/**
 * WebSocket handshake interceptor that authenticates the upgrade request.
 *
 * <p>Before the handshake it reads an access token from the request, verifies it with
 * {@link TokenUtil#verifyToken}, loads the matching {@link LoginInfo} from redis and publishes it
 * through {@link ThreadContext#setLoginInfo}. Requests without a valid token, or whose login
 * info is not found in redis, are rejected with a {@link ServiceException} carrying
 * {@link CommMsgCode#UNAUTHORIZED}.
 *
 * @author bjf
 */
@Slf4j
@Component
public class WebsocketInterceptor implements HandshakeInterceptor {

  /** Name of the request header (and fallback request parameter) that carries the token. */
  private static final String TOKEN_NAME = "x-websocket-token";

  /**
   * Expiry value handed to redis when the login info is written back.
   * Presumably 15 days expressed in seconds (86400 s per day) - confirm against RedisUtil.
   */
  private static final int LOGIN_TTL_SECONDS = 86400 * 15;

  @Autowired
  private RedisUtil redis;

  /**
   * Validates the access token and binds the caller's login info to the current thread.
   *
   * @param request the handshake request the token is read from
   * @param response the handshake response (not used)
   * @param wsHandler the target WebSocket handler (not used)
   * @param attributes handshake attributes (not modified)
   * @return always {@code true} when authentication succeeds
   * @throws ServiceException with {@link CommMsgCode#UNAUTHORIZED} when the token is blank,
   *     fails verification, or has no login info stored in redis
   */
  @Override
  public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
      WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
    // Token validation.
    // NOTE(review): the rejected token is written to the log verbatim - confirm this is acceptable.
    String accessToken = getAccessToken(request);
    if (StringUtils.isBlank(accessToken) || !TokenUtil.verifyToken(accessToken)) {
      log.error("invalid websocket token:{}", accessToken);
      throw new ServiceException(CommMsgCode.UNAUTHORIZED);
    }

    LoginInfo loginInfo = getLoginInfo(accessToken);
    // NOTE(review): lastTime is set after getLoginInfo has already written the object back to
    // redis, so the refreshed timestamp only lives on this in-memory instance - confirm intended.
    loginInfo.setLastTime(new Date());
    // Put the logged-in user's info into the thread-local context.
    // NOTE(review): nothing in this class clears it again (afterHandshake is empty) - verify
    // that the context is reset elsewhere before the thread is reused.
    ThreadContext.setLoginInfo(loginInfo);

    // Parameterised and guarded so the JSON serialisation is skipped when INFO is disabled.
    if (log.isInfoEnabled()) {
      log.info("websocket loginInfo：{}", JSONObject.toJSONString(loginInfo));
    }

    return true;
  }

  /**
   * No post-handshake work is performed.
   */
  @Override
  public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
      WebSocketHandler wsHandler, Exception exception) {
    // Intentionally empty.
  }

  /**
   * Loads the login info stored in redis for the given token and renews its expiry.
   *
   * @param accessToken the already verified access token
   * @return the login info found in redis, never {@code null}
   * @throws ServiceException with {@link CommMsgCode#UNAUTHORIZED} when redis holds no entry
   */
  private LoginInfo getLoginInfo(String accessToken) {
    // Fetch the logged-in user's info from redis.
    String userKey = UserRedisKey.TOKEN_API.as(accessToken);
    LoginInfo loginInfo = redis.get(userKey, LoginInfo.class);
    if (loginInfo == null) {
      throw new ServiceException(CommMsgCode.UNAUTHORIZED);
    }
    // Write the entry back to refresh its redis expiry.
    redis.setex(userKey, loginInfo, LOGIN_TTL_SECONDS);

    return loginInfo;
  }

  /**
   * Extracts the access token from the handshake request.
   *
   * <p>The {@code x-websocket-token} header is consulted first; when it is blank the request
   * parameter of the same name is used instead.
   *
   * @param request the handshake request
   * @return the token, or {@code null} when the request is not servlet based or carries no token
   */
  private String getAccessToken(ServerHttpRequest request) {
    if (!(request instanceof ServletServerHttpRequest)) {
      // Only servlet-backed requests expose the header/parameter API used below.
      return null;
    }

    HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
    String accessToken = servletRequest.getHeader(TOKEN_NAME);
    if (StringUtils.isBlank(accessToken)) {
      // Fall back to the request parameter when the header is absent or blank.
      accessToken = servletRequest.getParameter(TOKEN_NAME);
    }
    return accessToken;
  }

}
